"""
Phase 3: CCA robustness tests for the gravity model.

Tests:
  3a. Drop all CCA pairs (either reporter or partner is CCA)
  3b. Leave-one-region-out jackknife
  3c. Extensive vs intensive margin decomposition

Output:
  gravity_bilateral/output/tables/gravity_robustness.csv  (all specifications)
  gravity_bilateral/output/tables/jackknife_results.csv   (3b jackknife detail)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/gravity_bilateral")))
from src.model import PanelGLS

# Paths (project root is hard-coded; inputs and outputs are resolved below it)
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
PROCESSED_DIR = BASE_DIR / "data" / "processed"
OUTPUT_DIR = BASE_DIR / "output" / "tables"

# CCA countries (Caucasus & Central Asia) — the sensitivity group from followup paper
CCA_COUNTRIES = [
    'ARM', 'AZE', 'GEO', 'KAZ', 'KGZ', 'TJK', 'TKM', 'UZB',
]

# Region classification for the leave-one-region-out jackknife.
# Regions must be mutually exclusive: a country listed twice would be dropped
# in two different jackknife runs, blurring which region drives a change.
REGION_MAP = {
    'Advanced Europe': [
        'AUT', 'BEL', 'CHE', 'CYP', 'DEU', 'DNK', 'ESP', 'FIN', 'FRA',
        'GBR', 'GRC', 'IRL', 'ISL', 'ITA', 'LUX', 'MLT', 'NLD', 'NOR',
        'PRT', 'SWE',
    ],
    'EU New Members': [
        'BGR', 'CZE', 'EST', 'HRV', 'HUN', 'LTU', 'LVA', 'POL', 'ROU',
        'SVK', 'SVN',
    ],
    'East Asia': [
        'CHN', 'HKG', 'JPN', 'KOR', 'MNG', 'TWN',
    ],
    'Southeast Asia': [
        'IDN', 'KHM', 'LAO', 'MMR', 'MYS', 'PHL', 'SGP', 'THA', 'VNM',
    ],
    'South Asia': [
        'BGD', 'BTN', 'IND', 'LKA', 'NPL', 'PAK',
    ],
    'Latin America': [
        'ARG', 'BOL', 'BRA', 'CHL', 'COL', 'CRI', 'DOM', 'ECU', 'GTM',
        'HND', 'JAM', 'MEX', 'PER', 'PRY', 'URY', 'VEN',
    ],
    'Middle East & North Africa': [
        'ARE', 'BHR', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR', 'KWT',
        'LBN', 'MAR', 'OMN', 'QAT', 'SAU', 'TUN', 'YEM',
    ],
    'Sub-Saharan Africa': [
        'AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR', 'COD',
        'COM', 'CPV', 'ETH', 'GAB', 'GHA', 'GIN', 'GNB', 'GNQ', 'KEN',
        'LBR', 'LSO', 'MDG', 'MLI', 'MOZ', 'MUS', 'MWI', 'NAM', 'NGA',
        'RWA', 'SDN', 'SEN', 'SLE', 'SOM', 'SWZ', 'SYC', 'TCD', 'TGO',
        'TZA', 'UGA', 'ZAF', 'ZMB', 'ZWE',
    ],
    'CCA': CCA_COUNTRIES,
    # GEO is deliberately absent here: it belongs to CCA_COUNTRIES above.
    'Other Europe & CIS': [
        'ALB', 'BIH', 'BLR', 'MDA', 'MKD', 'MNE', 'RUS', 'SRB',
        'TUR', 'UKR',
    ],
    'Anglo-Saxon & Pacific': [
        'AUS', 'CAN', 'NZL', 'USA',
    ],
}


def estimate_gravity_model(df, dep_var, regressors, model_name, year_dummies=True):
    """Estimate the gravity model on ``df`` and tabulate the coefficients.

    Rows with a missing dependent variable, regressor, ``pair_id`` or ``year``
    are dropped first. Year dummies are then chosen from the years that are
    actually left in the estimation sample, so that the omitted base year is
    one that is present and no all-zero dummy column enters the design matrix
    (both would make the regressors collinear when the dependent variable is
    unavailable for part of the panel's time span).

    Args:
        df: Panel with ``pair_id``, ``year``, ``dep_var``, the regressors and,
            optionally, pre-built ``yr_<year>`` dummy columns.
        dep_var: Name of the dependent-variable column.
        regressors: Column names whose coefficients are reported.
        model_name: Label written to the ``model`` column of the result table.
        year_dummies: If true, add the available ``yr_<year>`` columns as
            controls (they are estimated but not reported).

    Returns:
        ``(gls, results)`` where ``gls`` is the fitted ``PanelGLS`` object and
        ``results`` is a DataFrame with one row per regressor plus summary
        rows (``_R_squared``, ``_N_obs``, ``_N_pairs``, ``_rho``); or
        ``(None, None)`` when fewer than 100 complete observations remain.
    """
    regressors = list(regressors)
    est = df.dropna(subset=[dep_var] + regressors + ['pair_id', 'year'])

    yr_cols = []
    if year_dummies:
        # The earliest year left in the sample is the omitted base category.
        sample_years = sorted(est['year'].unique())
        candidates = [f'yr_{int(y)}' for y in sample_years[1:]]
        yr_cols = [c for c in candidates if c in est.columns]
        if yr_cols:
            est = est.dropna(subset=yr_cols)
            # A dummy with no variation carries no information and would
            # make the design matrix rank-deficient.
            yr_cols = [c for c in yr_cols if est[c].nunique() > 1]

    # Too few observations for a meaningful panel estimate.
    if len(est) < 100:
        return None, None

    all_vars = regressors + yr_cols
    y = est[dep_var].values
    X = est[all_vars].values

    gls = PanelGLS()
    gls.fit(y, X, est['pair_id'].values, est['year'].values.astype(int))

    # NOTE(review): assumes PanelGLS reports coefficients in the column order
    # of X (regressors first, no leading intercept) — confirm against
    # src.model.
    results = []
    for i, v in enumerate(regressors):
        results.append({
            'model': model_name,
            'variable': v,
            'coefficient': gls.beta[i],
            'std_error': gls.se[i],
            't_stat': gls.tvalues[i],
            'p_value': gls.pvalues[i],
        })

    # Summary statistics share the coefficient column. The panel unit passed
    # to fit() is pair_id, so n_countries is reported as the number of pairs.
    summary_stats = [
        ('_R_squared', gls.r_squared),
        ('_N_obs', gls.n_obs),
        ('_N_pairs', gls.n_countries),
        ('_rho', gls.rho),
    ]
    for stat_name, stat_val in summary_stats:
        results.append({
            'model': model_name,
            'variable': stat_name,
            'coefficient': stat_val,
            'std_error': np.nan,
            't_stat': np.nan,
            'p_value': np.nan,
        })

    return gls, pd.DataFrame(results)


def _sig_stars(p_value):
    """Return the significance marker ('***', '**', '*' or '') for a p-value."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


def _result_row(results, variable):
    """Return the first row of a results table whose 'variable' matches."""
    return results[results['variable'] == variable].iloc[0]


def _exclude_countries(df, countries):
    """Return a copy of df without rows whose reporter or partner is listed."""
    involved = df['reporter'].isin(countries) | df['partner'].isin(countries)
    return df[~involved].copy()


def _banner(title, leading_blank=True):
    """Print a section title framed by two 70-character rules."""
    rule = "=" * 70
    print(f"\n{rule}" if leading_blank else rule)
    print(title)
    print(rule)


def _extensive_margin_logit(df, has_col, regressors):
    """Fit and print the extensive-margin logit; return its result frames."""
    print(f"\n  Extensive margin: P({has_col} = 1)")
    # statsmodels is only needed for this step, so it is imported lazily.
    import statsmodels.api as sm

    ext_df = df.dropna(subset=[has_col] + regressors)
    if len(ext_df) <= 500:
        return []

    y_ext = ext_df[has_col].values
    # has_constant='add' guarantees the intercept sits in column 0. The
    # default ('skip') silently adds nothing when a regressor happens to be
    # constant, which would shift every params[i + 1] lookup below by one.
    X_ext = sm.add_constant(ext_df[regressors].values, has_constant='add')

    model_label = '3c: Extensive (Logit)'
    try:
        logit = sm.Logit(y_ext, X_ext).fit(disp=0)
        print(f"  Logit: N = {len(ext_df):,}, Pseudo R² = {logit.prsquared:.4f}")
        print(f"  {'Variable':<30} {'Coef':>10} {'SE':>10} {'p-val':>8}")
        print(f"  {'-' * 60}")

        rows = []
        for i, v in enumerate(regressors):
            k = i + 1  # skip the intercept
            sig = _sig_stars(logit.pvalues[k])
            print(f"  {v:<30} {logit.params[k]:>10.4f} {logit.bse[k]:>10.4f} {logit.pvalues[k]:>8.4f} {sig}")
            rows.append({
                'model': model_label,
                'variable': v,
                'coefficient': logit.params[k],
                'std_error': logit.bse[k],
                't_stat': logit.tvalues[k],
                'p_value': logit.pvalues[k],
            })
        for stat_name, stat_val in [('_Pseudo_R_squared', logit.prsquared),
                                    ('_N_obs', len(ext_df))]:
            rows.append({
                'model': model_label,
                'variable': stat_name,
                'coefficient': stat_val,
                'std_error': np.nan,
                't_stat': np.nan,
                'p_value': np.nan,
            })
    except Exception as e:
        # Best effort: a failed logit (e.g. perfect separation or a singular
        # Hessian) must not abort the remaining robustness checks.
        print(f"  Logit failed: {e}")
        return []

    return [pd.DataFrame(rows)]


def main():
    """Run the Phase 3 robustness suite and write the result tables.

    Loads ``bilateral_panel.csv`` from ``PROCESSED_DIR``, picks the first
    sufficiently populated dependent variable, and estimates the gravity
    model on the full sample, without CCA countries (3a), leaving out one
    region at a time (3b), and split into extensive and intensive margins
    (3c). Progress and comparison tables are printed to stdout.

    Side effects:
        Creates ``OUTPUT_DIR`` if needed and writes ``jackknife_results.csv``
        (when any jackknife run succeeds) and ``gravity_robustness.csv``.

    Returns:
        DataFrame with the stacked results of every specification, or
        ``None`` if no dependent variable is usable or nothing was estimated.
    """
    _banner("PHASE 3: CCA ROBUSTNESS TESTS", leading_blank=False)

    # Load panel
    df = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")

    # Create year dummies (the first year is the omitted base category)
    years = sorted(df['year'].dropna().unique())
    for y in years[1:]:
        df[f'yr_{int(y)}'] = (df['year'] == y).astype(int)

    # Determine dep var and regressors (same as Phase 2): the first candidate
    # with more than 500 non-missing observations wins.
    dep_var = None
    for candidate in ['log_portfolio_total', 'log_portfolio_equity', 'log_fdi_outward']:
        if candidate in df.columns and df[candidate].notna().sum() > 500:
            dep_var = candidate
            break

    if dep_var is None:
        print("ERROR: No usable dependent variable!")
        return None

    gravity_vars = [v for v in ['log_dist', 'contiguity', 'common_lang_official',
                                 'colonial_ties', 'log_gdp_product']
                    if v in df.columns and df[v].notna().sum() > 500]
    demo_vars = [v for v in ['dZ_1', 'dZ_2', 'dZ_3'] if v in df.columns]
    base_regressors = gravity_vars + demo_vars

    print(f"  Dep var: {dep_var}")
    print(f"  Gravity vars: {gravity_vars}")
    print(f"  Demo vars: {demo_vars}")

    # Make sure the output directory exists before any table is written.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    all_results = []

    # === Reference: Full sample ===
    gls_full, res_full = estimate_gravity_model(
        df, dep_var, base_regressors, "Full Sample"
    )
    if res_full is not None:
        all_results.append(res_full)
        print(f"\n  Full sample: R²={gls_full.r_squared:.4f}, N={gls_full.n_obs:,}")
        for v in demo_vars:
            row = _result_row(res_full, v)
            sig = _sig_stars(row['p_value'])
            print(f"    {v}: {row['coefficient']:.4f} (p={row['p_value']:.4f}) {sig}")

    # ===================================================================
    # 3a. Drop all CCA pairs
    # ===================================================================
    _banner("  3a. DROP CCA PAIRS")

    df_no_cca = _exclude_countries(df, set(CCA_COUNTRIES))
    n_dropped = len(df) - len(df_no_cca)
    # n_dropped counts rows (pair-year observations), not distinct pairs.
    print(f"  Dropped {n_dropped:,} CCA-involved obs ({n_dropped / len(df) * 100:.1f}%)")
    print(f"  Remaining: {len(df_no_cca):,} obs, {df_no_cca['pair_id'].nunique():,} pairs")

    gls_no_cca, res_no_cca = estimate_gravity_model(
        df_no_cca, dep_var, base_regressors, "3a: Excl CCA"
    )
    if res_no_cca is not None:
        all_results.append(res_no_cca)

        # Compare with full sample (needs the reference estimate)
        if res_full is not None:
            print("\n  CCA sensitivity comparison:")
            print(f"  {'Variable':<15} {'Full':>12} {'Excl CCA':>12} {'Δ%':>8}")
            for v in demo_vars:
                full_coef = _result_row(res_full, v)['coefficient']
                excl_row = _result_row(res_no_cca, v)
                excl_coef = excl_row['coefficient']
                pct = (excl_coef - full_coef) / abs(full_coef) * 100 if full_coef != 0 else np.nan
                sig = _sig_stars(excl_row['p_value'])
                print(f"  {v:<15} {full_coef:>12.4f} {excl_coef:>12.4f} {pct:>7.1f}% {sig}")

    # Also test: drop CCA + non-commodity subset
    # NOTE(review): MDA, ALB and MKD are not in CCA_COUNTRIES — confirm this
    # list matches the intended "non-commodity" group.
    cca_non_commodity = ['ARM', 'GEO', 'KGZ', 'TJK', 'UZB', 'MDA', 'ALB', 'MKD']
    df_no_cca_nc = _exclude_countries(df, cca_non_commodity)
    gls_no_cca_nc, res_no_cca_nc = estimate_gravity_model(
        df_no_cca_nc, dep_var, base_regressors, "3a: Excl CCA non-commodity"
    )
    if res_no_cca_nc is not None:
        all_results.append(res_no_cca_nc)

    # ===================================================================
    # 3b. Leave-one-region-out jackknife
    # ===================================================================
    _banner("  3b. LEAVE-ONE-REGION-OUT JACKKNIFE")

    jackknife_rows = []
    for region_name, region_countries in REGION_MAP.items():
        df_excl = _exclude_countries(df, set(region_countries))

        if len(df_excl) < 500:
            print(f"  {region_name}: too few obs after dropping ({len(df_excl)}), skipping")
            continue

        gls_jk, res_jk = estimate_gravity_model(
            df_excl, dep_var, base_regressors, f"3b: Excl {region_name}"
        )
        if res_jk is None:
            continue
        all_results.append(res_jk)

        for v in demo_vars:
            row = _result_row(res_jk, v)
            jackknife_rows.append({
                'excluded_region': region_name,
                'variable': v,
                'coefficient': row['coefficient'],
                'p_value': row['p_value'],
                'n_obs': gls_jk.n_obs,
                'r_squared': gls_jk.r_squared,
            })

    if jackknife_rows:
        jk_df = pd.DataFrame(jackknife_rows)
        if res_full is not None:
            print("\n  Jackknife coefficient stability:")
            for v in demo_vars:
                vdf = jk_df[jk_df['variable'] == v]
                full_coef = _result_row(res_full, v)['coefficient']
                print(f"\n  {v} (full sample: {full_coef:.4f}):")
                print(f"    {'Excluded region':<30} {'Coef':>10} {'p-val':>8} {'Sig?':>5}")
                for _, row in vdf.iterrows():
                    sig = _sig_stars(row['p_value'])
                    print(f"    {row['excluded_region']:<30} {row['coefficient']:>10.4f} {row['p_value']:>8.4f} {sig}")
                print(f"    Range: [{vdf['coefficient'].min():.4f}, {vdf['coefficient'].max():.4f}]")
                # Stable = spread across runs below half the full-sample
                # magnitude; a zero reference coefficient counts as unstable.
                spread = vdf['coefficient'].std()
                stable = full_coef != 0 and spread / abs(full_coef) < 0.5
                print(f"    Stable: {'YES' if stable else 'NO'}")

        jk_df.to_csv(OUTPUT_DIR / "jackknife_results.csv", index=False)

    # ===================================================================
    # 3c. Extensive vs intensive margin
    # ===================================================================
    _banner("  3c. EXTENSIVE VS INTENSIVE MARGIN")

    # Determine the base flow column
    flow_col = dep_var.replace('log_', '')

    # --- Extensive margin: logit on has_flow ---
    has_col = f'has_{flow_col}'
    if has_col in df.columns:
        all_results.extend(
            _extensive_margin_logit(df, has_col, list(base_regressors))
        )

    # --- Intensive margin: OLS on log(flow) conditional on flow > 0 ---
    print(f"\n  Intensive margin: {dep_var} | flow > 0")
    df_intensive = df[df[flow_col] > 0].copy() if flow_col in df.columns else df.copy()
    gls_int, res_int = estimate_gravity_model(
        df_intensive, dep_var, base_regressors, "3c: Intensive (flow > 0)"
    )
    if res_int is not None:
        all_results.append(res_int)
        print(f"  Intensive margin: N = {gls_int.n_obs:,}, R² = {gls_int.r_squared:.4f}")

    # ===================================================================
    # Save all results
    # ===================================================================
    if not all_results:
        return None

    results_df = pd.concat(all_results, ignore_index=True)
    outfile = OUTPUT_DIR / "gravity_robustness.csv"
    results_df.to_csv(outfile, index=False)
    print(f"\n  Saved all robustness results: {outfile}")

    # Final summary: one row per demographic variable, one column per model
    _banner("ROBUSTNESS SUMMARY: ΔZ coefficients across specifications")

    summary_rows = []
    for v in demo_vars:
        row = {'variable': v}
        for model_name in results_df['model'].unique():
            mdf = results_df[(results_df['model'] == model_name) &
                              (results_df['variable'] == v)]
            if len(mdf) > 0:
                coef = mdf.iloc[0]['coefficient']
                sig = _sig_stars(mdf.iloc[0]['p_value'])
                row[model_name] = f"{coef:.4f}{sig}"
        summary_rows.append(row)

    summary = pd.DataFrame(summary_rows)
    print(summary.to_string(index=False))

    return results_df


# Script entry point: run the robustness suite when executed directly and bind
# the value returned by main() to the module-level name `results`.
if __name__ == "__main__":
    results = main()
